{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "16de7336",
   "metadata": {},
   "source": [
    "# Tagging and Extraction Using OpenAI functions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cb41f5f4-df8d-4d04-9eaa-193b8c29b00b",
   "metadata": {
    "height": 115,
    "tags": []
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import openai\n",
    "\n",
    "from dotenv import load_dotenv, find_dotenv\n",
    "_ = load_dotenv(find_dotenv()) # read local .env file\n",
    "openai.api_key = os.environ['OPENAI_API_KEY']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bfb2ede9-b0b4-4c8c-b00f-9a9911f38614",
   "metadata": {
    "height": 80
   },
   "outputs": [],
   "source": [
    "from typing import List\n",
    "from pydantic import BaseModel, Field\n",
    "from langchain.utils.openai_functions import convert_pydantic_to_openai_function"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b3ae5481-4155-4831-a5dd-4c183b3b4990",
   "metadata": {
    "height": 97
   },
   "outputs": [],
   "source": [
    "class Tagging(BaseModel):\n",
    "    \"\"\"Tag the piece of text with particular info.\"\"\"\n",
    "\n",
    "    # Typed as a plain string: the three expected labels appear only in the description text.\n",
    "    sentiment: str = Field(\n",
    "        description=\"sentiment of text, should be `pos`, `neg`, or `neutral`\",\n",
    "    )\n",
    "    # Also a plain string; the description asks for an ISO 639-1 code.\n",
    "    language: str = Field(\n",
    "        description=\"language of text (should be ISO 639-1 code)\",\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3cd656b2-62be-49d4-a277-b920c67203c1",
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": [
    "convert_pydantic_to_openai_function(Tagging)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ed8b15a7-450c-43d6-af44-ca63800a4619",
   "metadata": {
    "height": 46
   },
   "outputs": [],
   "source": [
    "from langchain.prompts import ChatPromptTemplate\n",
    "from langchain.chat_models import ChatOpenAI"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2cc5bac3-3783-4ce7-9de6-83c905fba7c5",
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": [
    "model = ChatOpenAI(temperature=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "798ac5b3-7e47-4cfb-8173-63900cf1e7f6",
   "metadata": {
    "height": 46
   },
   "outputs": [],
   "source": [
    "tagging_functions = [convert_pydantic_to_openai_function(Tagging)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9d558031-18cf-4791-b061-8911ce314605",
   "metadata": {
    "height": 97
   },
   "outputs": [],
   "source": [
    "prompt = ChatPromptTemplate.from_messages([\n",
    "    (\"system\", \"Think carefully, and then tag the text as instructed\"),\n",
    "    (\"user\", \"{input}\")\n",
    "])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "00607aff-a64e-42b7-8cf4-893e8393dd33",
   "metadata": {
    "height": 80
   },
   "outputs": [],
   "source": [
    "model_with_functions = model.bind(\n",
    "    functions=tagging_functions,\n",
    "    function_call={\"name\": \"Tagging\"}\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1d0bc519-2b8e-40d0-88ec-fc5ef0291f16",
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": [
    "tagging_chain = prompt | model_with_functions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8688fdba-0996-446c-a77b-318094944998",
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": [
    "tagging_chain.invoke({\"input\": \"I love langchain\"})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7df9aac5-7194-4a12-bda7-e7e83e705929",
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": [
    "tagging_chain.invoke({\"input\": \"non mi piace questo cibo\"})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0749e408-3878-4ddd-ad33-8d5b8418a3f2",
   "metadata": {
    "height": 46
   },
   "outputs": [],
   "source": [
    "from langchain.output_parsers.openai_functions import JsonOutputFunctionsParser"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dc45243e-7225-427d-b679-d737ebaec780",
   "metadata": {
    "height": 46
   },
   "outputs": [],
   "source": [
    "tagging_chain = prompt | model_with_functions | JsonOutputFunctionsParser()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7ba84e64-518c-40e1-a7c9-d174314bc3fb",
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": [
    "tagging_chain.invoke({\"input\": \"non mi piace questo cibo\"})"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "92e417b4-9306-413f-865f-e13dbd2d0196",
   "metadata": {},
   "source": [
    "## Extraction\n",
    "\n",
    "Extraction is similar to tagging, but it is used to pull multiple pieces of information out of the text."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "97338dce-a61a-4f8b-912c-6108dcb86183",
   "metadata": {
    "height": 97
   },
   "outputs": [],
   "source": [
    "from typing import Optional\n",
    "\n",
    "\n",
    "class Person(BaseModel):\n",
    "    \"\"\"Information about a person.\"\"\"\n",
    "\n",
    "    name: str = Field(description=\"person's name\")\n",
    "    # Explicit `None` default so `age` may be omitted under both pydantic v1 and v2;\n",
    "    # pydantic v2 treats an `Optional[...]` field without a default as required.\n",
    "    age: Optional[int] = Field(default=None, description=\"person's age\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f558fb45-f3a5-47be-8778-dddec2d00ba1",
   "metadata": {
    "height": 80
   },
   "outputs": [],
   "source": [
    "class Information(BaseModel):\n",
    "    \"\"\"Information to extract.\"\"\"\n",
    "\n",
    "    # Collection of `Person` records.\n",
    "    people: List[Person] = Field(\n",
    "        description=\"List of info about people\",\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5eeb6028-0ade-46f6-a899-ca10f185bb24",
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": [
    "convert_pydantic_to_openai_function(Information)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f16ce50e-bad7-4fbb-b9e2-0adae6fc55f2",
   "metadata": {
    "height": 63
   },
   "outputs": [],
   "source": [
    "# Single-entry function list built from the `Information` schema.\n",
    "extraction_functions = [convert_pydantic_to_openai_function(Information)]\n",
    "\n",
    "# `function_call` names `Information` explicitly, mirroring how the tagging model is bound.\n",
    "extraction_model = model.bind(\n",
    "    functions=extraction_functions,\n",
    "    function_call={\"name\": \"Information\"},\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3ed58069-3f7c-433e-93ed-5660d21435c9",
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": [
    "extraction_model.invoke(\"Joe is 30, his mom is Martha\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "172df097-250a-4813-b76d-21436c13056d",
   "metadata": {
    "height": 97
   },
   "outputs": [],
   "source": [
    "prompt = ChatPromptTemplate.from_messages([\n",
    "    (\"system\", \"Extract the relevant information, if not explicitly provided do not guess. Extract partial info\"),\n",
    "    (\"human\", \"{input}\")\n",
    "])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dfec48d7-25a2-46d7-a7a2-33284b577dd3",
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": [
    "extraction_chain = prompt | extraction_model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f19622af-4b01-4145-ae5a-af15b551a182",
   "metadata": {
    "height": 46
   },
   "outputs": [],
   "source": [
    "extraction_chain.invoke({\"input\": \"Joe is 30, his mom is Martha\"})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "38989431-9a61-461e-a86f-61a7dddab63c",
   "metadata": {
    "height": 46
   },
   "outputs": [],
   "source": [
    "extraction_chain = prompt | extraction_model | JsonOutputFunctionsParser()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "13ac89c9-1e8d-4e11-8bca-909a6c34d23c",
   "metadata": {
    "height": 46
   },
   "outputs": [],
   "source": [
    "extraction_chain.invoke({\"input\": \"Joe is 30, his mom is Martha\"})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f8db4b39-c28b-4c0e-8b9a-55ceb8ba817e",
   "metadata": {
    "height": 46
   },
   "outputs": [],
   "source": [
    "from langchain.output_parsers.openai_functions import JsonKeyOutputFunctionsParser"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2901d3ac-4dfc-4fb9-afd3-ab056489d3ae",
   "metadata": {
    "height": 46
   },
   "outputs": [],
   "source": [
    "extraction_chain = prompt | extraction_model | JsonKeyOutputFunctionsParser(key_name=\"people\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c7eff400-9f3a-42a8-ba4a-a3757e7939e2",
   "metadata": {
    "height": 46
   },
   "outputs": [],
   "source": [
    "extraction_chain.invoke({\"input\": \"Joe is 30, his mom is Martha\"})"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3b8524f3-a663-4007-ad59-52dfe00343e9",
   "metadata": {},
   "source": [
    "## Doing it for real\n",
    "\n",
    "We can apply tagging to a larger body of text.\n",
    "\n",
    "For example, let's load this blog post and extract tag information from a subset of the text."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bc7c4a8c-d4c8-4400-b1d3-5c7c0c25786e",
   "metadata": {
    "height": 80
   },
   "outputs": [],
   "source": [
    "from langchain.document_loaders import WebBaseLoader\n",
    "loader = WebBaseLoader(\"https://lilianweng.github.io/posts/2023-06-23-agent/\")\n",
    "documents = loader.load()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dbd2e910-e0e0-447a-8118-6fe792f04e15",
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": [
    "doc = documents[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b5e39f78-ff18-4016-a751-7d0e4c82bb36",
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": [
    "page_content = doc.page_content[:10000]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bb8d98ee-e9b2-4ce8-982a-c06dc006da76",
   "metadata": {
    "height": 29,
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "print(page_content[:1000])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9db0d107-5284-4253-b769-dfedb97ce95d",
   "metadata": {
    "height": 114
   },
   "outputs": [],
   "source": [
    "class Overview(BaseModel):\n",
    "    \"\"\"Overview of a section of text.\"\"\"\n",
    "\n",
    "    summary: str = Field(\n",
    "        description=\"Provide a concise summary of the content.\",\n",
    "    )\n",
    "    language: str = Field(\n",
    "        description=\"Provide the language that the content is written in.\",\n",
    "    )\n",
    "    # Keywords are declared as one string field, not a list.\n",
    "    keywords: str = Field(\n",
    "        description=\"Provide keywords related to the content.\",\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6bfe7fe6-a86d-416c-a7f2-373dc70fcba5",
   "metadata": {
    "height": 165
   },
   "outputs": [],
   "source": [
    "# Single-entry function list built from the `Overview` schema.\n",
    "overview_tagging_function = [convert_pydantic_to_openai_function(Overview)]\n",
    "\n",
    "tagging_model = model.bind(\n",
    "    functions=overview_tagging_function,\n",
    "    function_call={\"name\": \"Overview\"},\n",
    ")\n",
    "\n",
    "# NOTE(review): `prompt` is a notebook-global that is reassigned in several cells; when the\n",
    "# notebook is run top to bottom, the value in scope here is the extraction prompt rather than\n",
    "# the tagging prompt -- confirm this is intended.\n",
    "tagging_chain = prompt | tagging_model | JsonOutputFunctionsParser()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "518f2337-a8eb-4eff-b764-b70e48545d19",
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": [
    "tagging_chain.invoke({\"input\": page_content})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5be5d604-3d95-476a-afe4-b46b021534fc",
   "metadata": {
    "height": 165
   },
   "outputs": [],
   "source": [
    "class Paper(BaseModel):\n",
    "    \"\"\"Information about papers mentioned.\"\"\"\n",
    "\n",
    "    title: str\n",
    "    # Explicit `None` default so `author` may be omitted under both pydantic v1 and v2;\n",
    "    # pydantic v2 treats an `Optional[...]` field without a default as required.\n",
    "    author: Optional[str] = None\n",
    "\n",
    "\n",
    "class Info(BaseModel):\n",
    "    \"\"\"Information to extract\"\"\"\n",
    "\n",
    "    # Collection of `Paper` records.\n",
    "    papers: List[Paper]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8e56ca98-d05d-4199-a6ec-fe485b630da6",
   "metadata": {
    "height": 165
   },
   "outputs": [],
   "source": [
    "paper_extraction_function = [\n",
    "    convert_pydantic_to_openai_function(Info)\n",
    "]\n",
    "extraction_model = model.bind(\n",
    "    functions=paper_extraction_function, \n",
    "    function_call={\"name\":\"Info\"}\n",
    ")\n",
    "extraction_chain = prompt | extraction_model | JsonKeyOutputFunctionsParser(key_name=\"papers\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f082b7a0-b7ea-488a-923c-aba8e4b185a0",
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": [
    "extraction_chain.invoke({\"input\": page_content})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2a8dcce7-7032-49a8-893d-1659e2f51f03",
   "metadata": {
    "height": 199
   },
   "outputs": [],
   "source": [
    "# System prompt for paper extraction: asks for mentioned papers only and for an empty list when none are found.\n",
    "template = \"\"\"An article will be passed to you. Extract from it all papers that are mentioned by this article.\n",
    "\n",
    "Do not extract the name of the article itself. If no papers are mentioned that's fine - you don't need to extract any! Just return an empty list.\n",
    "\n",
    "Do not make up or guess ANY extra information. Only extract what exactly is in the text.\"\"\"\n",
    "\n",
    "# Rebinds the notebook-global `prompt`; chains built after this cell use the stricter instructions.\n",
    "prompt = ChatPromptTemplate.from_messages([\n",
    "    (\"system\", template),\n",
    "    (\"human\", \"{input}\")\n",
    "])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a5f4c8cc-d9a4-4556-a89e-94ec981c6f5e",
   "metadata": {
    "height": 46
   },
   "outputs": [],
   "source": [
    "extraction_chain = prompt | extraction_model | JsonKeyOutputFunctionsParser(key_name=\"papers\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b63e00b6-0f06-4e74-948e-04805e66f40c",
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": [
    "extraction_chain.invoke({\"input\": page_content})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "365bfa5b-f0d2-4a19-8db1-1b4a0f51703d",
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": [
    "extraction_chain.invoke({\"input\": \"hi\"})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7a3e18b5-0692-49dc-b589-5d4ae5a43fdd",
   "metadata": {
    "height": 63
   },
   "outputs": [],
   "source": [
    "from langchain.text_splitter import RecursiveCharacterTextSplitter\n",
    "text_splitter = RecursiveCharacterTextSplitter(chunk_overlap=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a851b07c-0f73-4cf1-801f-9832190db93a",
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": [
    "splits = text_splitter.split_text(doc.page_content)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fbd6fd3c-8fca-4f16-a256-a66577b48eb1",
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": [
    "len(splits)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6f1309bf-16a5-43ef-aa8e-427cf5120938",
   "metadata": {
    "height": 97
   },
   "outputs": [],
   "source": [
    "def flatten(matrix):\n",
    "    \"\"\"Concatenate the rows of ``matrix`` into one flat list.\n",
    "\n",
    "    Args:\n",
    "        matrix: An iterable whose items are themselves iterable.\n",
    "\n",
    "    Returns:\n",
    "        A new list holding every item of every row, in order.\n",
    "    \"\"\"\n",
    "    return [item for row in matrix for item in row]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5af35566-8f5f-4e02-86d2-c1a97ca4e573",
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": [
    "flatten([[1, 2], [3, 4]])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b77be284-fca5-4ab0-ab36-b83b3ab53136",
   "metadata": {
    "height": 29,
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "print(splits[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "67bded5d-de85-488c-bd3f-d83acbcfea33",
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": [
    "from langchain.schema.runnable import RunnableLambda"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "04233d0a-fe44-45ed-9145-af3caaa86863",
   "metadata": {
    "height": 80
   },
   "outputs": [],
   "source": [
    "# Split the incoming text with `text_splitter` and wrap each piece as an {\"input\": ...} dict.\n",
    "prep = RunnableLambda(\n",
    "    lambda text: [{\"input\": chunk} for chunk in text_splitter.split_text(text)]\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fa8c2f1e-f3ac-4d9c-8ded-541778a32ef9",
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": [
    "prep.invoke(\"hi\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3f7d0255-4b58-42c9-a747-44768a866633",
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": [
    "chain = prep | extraction_chain.map() | flatten"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0dfa1f67-7246-4f4f-8967-4d7266df7b7a",
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": [
    "chain.invoke(doc.page_content)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "68cb460a-2b56-42be-acf6-4f61e40027f4",
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "32277f95-fdb1-4873-b513-4aa9f7437308",
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cbf4b7b3-13f8-4cb6-b658-a0b0af88034d",
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "23fe8752-104b-4672-a024-8ef182bde883",
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "918dd6d5-fec4-4d08-be91-6b0bb56340ae",
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "caf6b3ac-0367-4bf9-b01b-acebd6d1cfbf",
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d2d90c83-cc84-4d44-bfff-35a30fb72f63",
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b31bb740-2d2a-4fc8-8887-4e9da113e82d",
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5409db4a-a52b-4943-944e-7221eda0339e",
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0cdcc19b-7b03-4d7e-a90d-1f8c33d1f9d9",
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "60c5f869-67cc-4b66-8d14-7d19dd01b22a",
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.19"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
